/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 *
 * @brief matmul iterate controller (MoveNext/Reset) ut for ascend910B1
 *
 */
#include <gtest/gtest.h>
#include "kernel_operator.h"
#include "lib/matmul/tiling.h"
#include "impl/matmul/matmul_call_back.h"
#include "impl/matmul/modules/matmul_module.h"
#include "impl/matmul/modules/matmul_policy.h"
#include "impl/matmul/modules/matmul_private_modules.h"

using namespace std;
using namespace AscendC;
using namespace matmul;

using A_TYPE = matmul::MatmulType<AscendC::TPosition::TSCM, CubeFormat::ND, half>;
using B_TYPE = matmul::MatmulType<AscendC::TPosition::TSCM, CubeFormat::ND, half>;
using C_TYPE = matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
using BIAS_TYPE = matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;

template <typename IMPL, class INPUT_TYPE, const auto& MM_CFG>
class CustomCopyCubeIn {
public:
    void Reset() {
            clearedCount++;
    }

public:
    uint32_t clearedCount {0};
};

namespace {
template <const auto& MM_CFG, typename IMPL, typename A_TYPE, typename B_TYPE, typename C_TYPE, typename BIAS_TYPE>
class CustomMatmulPolicy : public matmul::MatmulPolicy<MM_CFG, IMPL, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>
{
public:
    using CopyCubeInA = CustomCopyCubeIn<IMPL, MatmulInputAType<A_TYPE, typename A_TYPE::T>, MM_CFG>;
    using CopyCubeInB = CustomCopyCubeIn<IMPL, MatmulInputBType<B_TYPE, typename A_TYPE::T>, MM_CFG>;
};

template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG,
          class MM_CB = MatmulCallBackFunc<nullptr, nullptr, nullptr>, MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)>
class MatmulImpl :
    MATMUL_IMPORT_MODULE_PRIVATE(IterateController),
    MATMUL_IMPORT_MODULE(CopyCubeInA),
    MATMUL_IMPORT_MODULE(CopyCubeInB)
{
    using VAR_PARAMS =
        typename MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(MM_CFG)>::PARAMS;

    MATMUL_ALLOW_USING_PRIVATE(IterateController);
    MATMUL_ALLOW_USING(CopyCubeInA);
    MATMUL_ALLOW_USING(CopyCubeInB);

public:
    using IterateController::MoveNext;
    using IterateController::Reset;

public:
    MatmulImpl() {
        InitVar();
    }

    void SetTiling(IterateOrder order, int32_t stepM, uint32_t stepN) {
        tiling.iterateOrder = static_cast<int32_t>(order);
        tiling.stepM = stepM;
        tiling.stepN = stepN;

        this->Reset();
    }

    void SetMParams(int32_t curPos, int32_t iter, int32_t stepIdx,  int32_t curStep) {
        var.curM_ = curPos;
        var.mIter_ = iter;
        var.stepMIdx_ = stepIdx;
        var.curStepM_ = curStep;
    }

    void SetNParams(int32_t curPos, int32_t iter, int32_t stepIdx,  int32_t curStep) {
        var.curN_ = curPos;
        var.nIter_ = iter;
        var.stepNIdx_ = stepIdx;
        var.curStepN_ = curStep;
    }

    VAR_PARAMS& GetVar() {
        return var;
    }

    void InitVar() {
        var.tiling_.SetTiling(&tiling);
        var.tpipe_ = &pipe;
    }

private:
    TCubeTiling tiling;
    TPipe pipe;
    VAR_PARAMS var;
};
}

class test_matmul_iterator_controller : public testing::Test {
protected:
    void SetUp() {}
    void TearDown() {}

private:
    MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, CFG_NORM, MatmulCallBackFunc<nullptr, nullptr, nullptr>,
               CustomMatmulPolicy>
        mm;
};

TEST_F(test_matmul_iterator_controller, first_iter_order_M) {
    mm.SetTiling(IterateOrder::ORDER_M, 4, 2);
    mm.SetMParams(0, 4, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    
    bool isFinished = mm.MoveNext();

    ASSERT_TRUE(isFinished);

    ASSERT_EQ(mm.GetVar().curStepM_, 4);
    ASSERT_EQ(mm.GetVar().curM_, 0);
    ASSERT_EQ(mm.GetVar().curN_, 0);
}

TEST_F(test_matmul_iterator_controller, first_iter_order_N) {
    mm.SetTiling(IterateOrder::ORDER_N, 4, 2);
    mm.SetMParams(0, 4, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    
    bool isFinished = mm.MoveNext();

    ASSERT_TRUE(isFinished);

    ASSERT_EQ(mm.GetVar().curStepN_, 2);
    ASSERT_EQ(mm.GetVar().curN_, 0);
}

TEST_F(test_matmul_iterator_controller, order_M_iter_four_times) {
    mm.SetTiling(IterateOrder::ORDER_M, 4, 2);
    mm.SetMParams(0, 4, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    int32_t cnt = 0;
    while(mm.MoveNext()) {
        cnt++;
    }

    ASSERT_EQ(cnt, 8);
}

TEST_F(test_matmul_iterator_controller, order_N_iter_four_times) {
    mm.SetTiling(IterateOrder::ORDER_N, 4, 2);
    mm.SetMParams(0, 4, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    int32_t cnt = 0;
    while(mm.MoveNext()) {
        cnt++;
    }

    ASSERT_EQ(cnt, 8);
}


TEST_F(test_matmul_iterator_controller, order_M_iter_twice) {
    mm.SetTiling(IterateOrder::ORDER_M, 4, 2);
    mm.SetMParams(0, 1, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    auto isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curN_, 0);
    isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curN_, 1);
    ASSERT_TRUE(isFinished);
    isFinished = mm.MoveNext();
    ASSERT_FALSE(isFinished);
    ASSERT_EQ(mm.GetVar().curM_, 0);
}

TEST_F(test_matmul_iterator_controller, order_N_iter_twice) {
    mm.SetTiling(IterateOrder::ORDER_N, 4, 2);
    mm.SetMParams(0, 2, 0, 0);
    mm.SetNParams(0, 1, 0, 0);
    auto isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curM_, 0);
    isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curM_, 1);
    ASSERT_TRUE(isFinished);
    isFinished = mm.MoveNext();
    ASSERT_FALSE(isFinished);
    ASSERT_EQ(mm.GetVar().curN_, 0);
}

// test when n-dimension is finished in OrderM case
TEST_F(test_matmul_iterator_controller, order_M_n_is_finished) {
    mm.SetTiling(IterateOrder::ORDER_M, 4, 2);
    mm.SetMParams(0, 2, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    // first iter
    auto isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curM_, 0);
    // n-dimension is finished
    isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curN_, 1);
    ASSERT_TRUE(isFinished);
    (void)mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curN_, 0);
    ASSERT_EQ(mm.GetVar().curM_, 1);
}

// test when m-dimension is finished in OrderN case
TEST_F(test_matmul_iterator_controller, order_N_m_is_finished) {
    mm.SetTiling(IterateOrder::ORDER_N, 4, 2);
    mm.SetMParams(0, 2, 0, 0);
    mm.SetNParams(0, 2, 0, 0);
    // first iter
    auto isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curN_, 0);
    // n-dimension is finished
    isFinished = mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curM_, 1);
    ASSERT_TRUE(isFinished);
    (void)mm.MoveNext();
    ASSERT_EQ(mm.GetVar().curM_, 0);
    ASSERT_EQ(mm.GetVar().curN_, 1);
}
